{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f11db778",
   "metadata": {},
   "source": [
    "\n",
    "这里以中文BERT为例，实现提及聚类："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c9861b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<>:3: SyntaxWarning: invalid escape sequence '\\h'\n",
      "<>:4: SyntaxWarning: invalid escape sequence '\\h'\n",
      "<>:3: SyntaxWarning: invalid escape sequence '\\h'\n",
      "<>:4: SyntaxWarning: invalid escape sequence '\\h'\n",
      "C:\\Users\\spring\\AppData\\Local\\Temp\\ipykernel_22400\\928338747.py:3: SyntaxWarning: invalid escape sequence '\\h'\n",
      "  tokenizer = AutoTokenizer.from_pretrained(\"E:\\huggingface\\hub\\models--bert-base-chinese\")\n",
      "C:\\Users\\spring\\AppData\\Local\\Temp\\ipykernel_22400\\928338747.py:4: SyntaxWarning: invalid escape sequence '\\h'\n",
      "  model = AutoModel.from_pretrained(\"E:\\huggingface\\hub\\models--bert-base-chinese\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['[CLS]', '小', '明', '给', '小', '红', '一', '束', '花', '，', '她', '很', '高', '兴', '。', '[SEP]']\n",
      "[101, 2207, 3209, 5314, 2207, 5273, 671, 3338, 5709, 8024, 1961, 2523, 7770, 1069, 511, 102]\n",
      "torch.Size([1, 16, 768])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"E:\\huggingface\\hub\\models--bert-base-chinese\")\n",
    "model = AutoModel.from_pretrained(\"E:\\huggingface\\hub\\models--bert-base-chinese\")\n",
    "\n",
    "# 进行分词\n",
    "sentence=\"小明给小红一束花，她很高兴。\"\n",
    "subtokenized_sentence=tokenizer.tokenize(sentence)\n",
    "subtokenized_sentence = [tokenizer._cls_token] + \\\n",
    "    subtokenized_sentence + [tokenizer._sep_token]\n",
    "subtoken_ids_sentence = tokenizer.convert_tokens_to_ids(\\\n",
    "    subtokenized_sentence)\n",
    "print(subtokenized_sentence)\n",
    "print(subtoken_ids_sentence)\n",
    "\n",
    "# 计算对应的特征\n",
    "outputs = model(torch.Tensor(subtoken_ids_sentence).\\\n",
    "    unsqueeze(0).long())\n",
    "hidden_states = outputs[0]\n",
    "print(hidden_states.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49d7247c",
   "metadata": {},
   "source": [
    "假设已经通过提及检测模型找到了句子中的提及，这里用每个提及的第一个子词（在中文中也就是第一个字）作为词特征："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "60d1eedf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([4, 768])\n"
     ]
    }
   ],
   "source": [
    "# 提及的跨度，假设(-1,0)表示[CLS]的跨度，用于表示空提及[NA]，\n",
    "# 在实际训练中也可以额外定义个空提及符号\n",
    "mention_spans = [(-1,0),(0,2),(3,5),(10,11)]\n",
    "word_features = torch.concat([hidden_states[0,x+1].unsqueeze(0)\\\n",
    "    for (x,y) in mention_spans],0)\n",
    "print(word_features.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd5d787f",
   "metadata": {},
   "source": [
    "首先，通过双仿射函数计算打分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ca991cb8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[     -inf,      -inf,      -inf,      -inf],\n",
      "         [-494.1196,      -inf,      -inf,      -inf],\n",
      "         [-684.3828, -726.1643,      -inf,      -inf],\n",
      "         [-318.4721, -654.2186, -698.3372,      -inf]]],\n",
      "       grad_fn=<AddBackward0>)\n",
      "tensor([[0, 0, 0]])\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "sys.path.append('./code')\n",
    "from my_utils import Biaffine\n",
    "biaffine = Biaffine(word_features.shape[1])\n",
    "\n",
    "# 对word_features进行打分\n",
    "scores = biaffine(word_features.unsqueeze(0),\\\n",
    "    word_features.unsqueeze(0))\n",
    "# 由于只关注当前提及之前的提及是否与其进行共指，\n",
    "# 因此将它转换为下三角函数，并且为上三角部分置为负无穷：\n",
    "scores = scores.tril(diagonal=-1)\n",
    "inf_mask = torch.zeros_like(scores)-torch.inf\n",
    "inf_mask = inf_mask.triu()\n",
    "scores += inf_mask\n",
    "print(scores)\n",
    "print(scores.argmax(-1)[:,1:])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de88f8d7",
   "metadata": {},
   "source": [
    "由于模型未经过训练，因此仅通过双仿射函数初始化获得结构显然是错误的。我们可以训练模型，计算损失函数计算方式如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9f2abedd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(111.9155, grad_fn=<NllLossBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# 只计算除了[NA]以外的提及的损失\n",
    "target = torch.Tensor([0,0,1]).long()\n",
    "loss_func = torch.nn.NLLLoss()\n",
    "loss = loss_func(torch.nn.functional.log_softmax(scores[:,1:].\\\n",
    "    squeeze(0),-1),target)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "880e61fb",
   "metadata": {},
   "source": [
    "接下来通过点积计算打分。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f6a04704",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[    -inf,     -inf,     -inf,     -inf],\n",
      "        [235.2012,     -inf,     -inf,     -inf],\n",
      "        [188.3145, 267.1165,     -inf,     -inf],\n",
      "        [221.3709, 101.3910, 292.7802,     -inf]], grad_fn=<AddBackward0>)\n",
      "tensor([0, 1, 2])\n"
     ]
    }
   ],
   "source": [
    "scores2 = torch.matmul(word_features,word_features.T)\n",
    "scores2 = scores2.tril(diagonal=-1)\n",
    "inf_mask = torch.zeros_like(scores2)-torch.inf\n",
    "inf_mask = inf_mask.triu()\n",
    "scores2 += inf_mask\n",
    "print(scores2)\n",
    "print(scores2.argmax(-1)[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "842f4689-f832-4218-9910-3f693cefb278",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch",
   "language": "python",
   "name": "env_name"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
